{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c285c1c3-9d4c-4b7b-b8a7-95e80594589c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set working directory to the folder 'electra' which is located in the replication folder\n",
    "# %cd electra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cec5963a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/TobiasWidmann/Documents/EUI - European University Institute/Papers/Emotional Dictionary/Political Analysis/Replication 2/replication_folder/electra\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import json\n",
    "import pickle\n",
    "import subprocess\n",
    "import time\n",
    "\n",
    "import datasets\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import transformers\n",
    "from datasets import Dataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm.notebook import tqdm\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from transformers.modeling_outputs import SequenceClassifierOutput\n",
    "from transformers.trainer_callback import EarlyStoppingCallback\n",
    "\n",
    "import helper.training as tr\n",
    "\n",
    "pd.set_option(\"display.precision\", 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c1b9089a-d13d-4a5e-bdcd-3bc213f1a47e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#import time\n",
    "#start_time = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64879990",
   "metadata": {},
   "source": [
    "### Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "57287633",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"german-nlp-group/electra-base-german-uncased\"\n",
    "DIR_OUTPT = \"./results\"\n",
    "DIR_LOG = \"./logs\"\n",
    "DIR_TRAINED_MODEL = \"./models/final_replication\"\n",
    "SIZE_VALIDATION_SET = 0.1\n",
    "SEED = 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f20580a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd00f192",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cd892e2",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9d493997",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of training set:\t 8017\n",
      "Size of validation set:\t 891\n",
      "Size of test set:\t 990\n"
     ]
    }
   ],
   "source": [
    "columns_a = [\"Text\", \"type\", \"sentences\"]\n",
    "emotions = [\n",
    "    \"hf_anger\",\n",
    "    \"hf_fear\",\n",
    "    \"hf_disgust\",\n",
    "    \"hf_sadness\",\n",
    "    \"hf_joy\",\n",
    "    \"hf_enthusiasm\",\n",
    "    \"hf_pride\",\n",
    "    \"hf_hope\",\n",
    "]\n",
    "\n",
    "df_train_validation = pd.read_pickle(\"./data/labeled_training_data_df.pkl\")\n",
    "df_train_validation = df_train_validation[columns_a + emotions]\n",
    "df_train_validation = df_train_validation.astype(\n",
    "    {\n",
    "        \"hf_anger\": int,\n",
    "        \"hf_fear\": int,\n",
    "        \"hf_disgust\": int,\n",
    "        \"hf_sadness\": int,\n",
    "        \"hf_joy\": int,\n",
    "        \"hf_enthusiasm\": int,\n",
    "        \"hf_pride\": int,\n",
    "        \"hf_hope\": int,\n",
    "        \"type\": str,\n",
    "    }\n",
    ")\n",
    "df_train_validation[\"list\"] = df_train_validation.apply(\n",
    "    lambda x: [\n",
    "        x[\"hf_anger\"],\n",
    "        x[\"hf_fear\"],\n",
    "        x[\"hf_disgust\"],\n",
    "        x[\"hf_sadness\"],\n",
    "        x[\"hf_joy\"],\n",
    "        x[\"hf_enthusiasm\"],\n",
    "        x[\"hf_pride\"],\n",
    "        x[\"hf_hope\"],\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "df_test = pd.read_pickle(\"./data/labeled_test_data_df.pkl\")\n",
    "# df_train_validation = df_train_validation.sample(n=1000)\n",
    "df_test = df_test[columns_a + emotions]\n",
    "df_test = df_test.astype(\n",
    "    {\n",
    "        \"hf_anger\": int,\n",
    "        \"hf_fear\": int,\n",
    "        \"hf_disgust\": int,\n",
    "        \"hf_sadness\": int,\n",
    "        \"hf_joy\": int,\n",
    "        \"hf_enthusiasm\": int,\n",
    "        \"hf_pride\": int,\n",
    "        \"hf_hope\": int,\n",
    "        \"type\": str,\n",
    "    }\n",
    ")\n",
    "df_test[\"list\"] = df_test.apply(\n",
    "    lambda x: [\n",
    "        x[\"hf_anger\"],\n",
    "        x[\"hf_fear\"],\n",
    "        x[\"hf_disgust\"],\n",
    "        x[\"hf_sadness\"],\n",
    "        x[\"hf_joy\"],\n",
    "        x[\"hf_enthusiasm\"],\n",
    "        x[\"hf_pride\"],\n",
    "        x[\"hf_hope\"],\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "\n",
    "df_train, df_validation = train_test_split(\n",
    "    df_train_validation, test_size=0.1, random_state=SEED\n",
    ")\n",
    "\n",
    "print(\"Size of training set:\\t\", len(df_train))\n",
    "print(\"Size of validation set:\\t\", len(df_validation))\n",
    "print(\"Size of test set:\\t\", len(df_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4db9249",
   "metadata": {},
   "source": [
    "### Convert to Dataset format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "358c13bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "773c5223c91842fc95cffb7bb9070448",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8e65d725fb84d20a17f09b3a0c7cd15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset_train = Dataset.from_pandas(df_train)\n",
    "dataset_validation = Dataset.from_pandas(df_validation)\n",
    "dataset_test = Dataset.from_pandas(df_test)\n",
    "dataset_test_fb = dataset_test.filter(lambda x: x[\"type\"] == \"fb_sent\")\n",
    "dataset_test_ps = dataset_test.filter(lambda x: x[\"type\"] == \"ps_sent\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b23e8dc",
   "metadata": {},
   "source": [
    "### Tokenize dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "04fde3eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "abee61cc44ec4c9a8f613796acf62bd0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/467 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97d337c5baed42d897903e6222b0a747",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/424M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at german-nlp-group/electra-base-german-uncased were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight']\n",
      "- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at german-nlp-group/electra-base-german-uncased and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8d2cbb1b1c54812b0cb419c9dd8873d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/103 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cce813c274c549409b26b6e76feb2884",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/269k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8cbedf8c8fed4ba9bc7f44b9c4eecc3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e192509641f4d42bad6952922e3fb53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# load model and tokenizer\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=8)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# preprocess data\n",
    "field_text = \"sentences\"\n",
    "field_label = \"list\"\n",
    "\n",
    "dataset_train = Dataset.from_pandas(df_train)\n",
    "dataset_validation = Dataset.from_pandas(df_validation)\n",
    "dataset_test = Dataset.from_pandas(df_test)\n",
    "dataset_test_fb = dataset_test.filter(lambda x: x[\"type\"] == \"fb_sent\")\n",
    "dataset_test_ps = dataset_test.filter(lambda x: x[\"type\"] == \"ps_sent\")\n",
    "\n",
    "# tokenize data\n",
    "train_encodings = tokenizer(dataset_train[field_text], truncation=True, padding=True)\n",
    "val_encodings = tokenizer(dataset_validation[field_text], truncation=True, padding=True)\n",
    "test_encodings = tokenizer(dataset_test[field_text], truncation=True, padding=True)\n",
    "test_fb_encodings = tokenizer(\n",
    "    dataset_test_fb[field_text], truncation=True, padding=True\n",
    ")\n",
    "test_ps_encodings = tokenizer(\n",
    "    dataset_test_ps[field_text], truncation=True, padding=True\n",
    ")\n",
    "\n",
    "train_dataset = tr.EmotionDataset(train_encodings, dataset_train[field_label])\n",
    "val_dataset = tr.EmotionDataset(val_encodings, dataset_validation[field_label])\n",
    "test_dataset = tr.EmotionDataset(test_encodings, dataset_test[field_label])\n",
    "test_fb_dataset = tr.EmotionDataset(test_fb_encodings, dataset_test_fb[field_label])\n",
    "test_ps_dataset = tr.EmotionDataset(test_ps_encodings, dataset_test_ps[field_label])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3861e15b",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4cb83d04",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 8017\n",
      "  Num Epochs = 4\n",
      "  Instantaneous batch size per device = 32\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 32\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 1004\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1004' max='1004' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1004/1004 11:30, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Loss</th>\n",
       "      <th>Accuracy Thresh</th>\n",
       "      <th>Runtime</th>\n",
       "      <th>Samples Per Second</th>\n",
       "      <th>Steps Per Second</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>7.730285</td>\n",
       "      <td>0.360063</td>\n",
       "      <td>0.840909</td>\n",
       "      <td>4.195700</td>\n",
       "      <td>235.956000</td>\n",
       "      <td>7.389000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.395900</td>\n",
       "      <td>5.912909</td>\n",
       "      <td>0.311027</td>\n",
       "      <td>0.863889</td>\n",
       "      <td>4.397400</td>\n",
       "      <td>225.132000</td>\n",
       "      <td>7.050000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.395900</td>\n",
       "      <td>5.549194</td>\n",
       "      <td>0.313157</td>\n",
       "      <td>0.865404</td>\n",
       "      <td>4.380600</td>\n",
       "      <td>225.996000</td>\n",
       "      <td>7.077000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.236200</td>\n",
       "      <td>5.504445</td>\n",
       "      <td>0.314996</td>\n",
       "      <td>0.867677</td>\n",
       "      <td>4.399700</td>\n",
       "      <td>225.013000</td>\n",
       "      <td>7.046000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n",
      "Saving model checkpoint to ./results/checkpoint-251\n",
      "Configuration saved in ./results/checkpoint-251/config.json\n",
      "Model weights saved in ./results/checkpoint-251/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n",
      "Saving model checkpoint to ./results/checkpoint-502\n",
      "Configuration saved in ./results/checkpoint-502/config.json\n",
      "Model weights saved in ./results/checkpoint-502/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n",
      "Saving model checkpoint to ./results/checkpoint-753\n",
      "Configuration saved in ./results/checkpoint-753/config.json\n",
      "Model weights saved in ./results/checkpoint-753/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n",
      "Saving model checkpoint to ./results/checkpoint-1004\n",
      "Configuration saved in ./results/checkpoint-1004/config.json\n",
      "Model weights saved in ./results/checkpoint-1004/pytorch_model.bin\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n",
      "Loading best model from ./results/checkpoint-1004 (score: 5.504445176492018).\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='94' max='31' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [31/31 00:13]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Configuration saved in ./models/final2/german-nlp-group/electra-base-german-uncased/config.json\n",
      "Model weights saved in ./models/final2/german-nlp-group/electra-base-german-uncased/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=DIR_OUTPT,  # output directory\n",
    "    num_train_epochs=4,  # total # of training epochs\n",
    "    per_device_train_batch_size=32,  # batch size per device during training\n",
    "    per_device_eval_batch_size=32,  # batch size for evaluation\n",
    "    warmup_steps=250,  # number of warmup steps for learning rate scheduler\n",
    "    weight_decay=0.01,  # strength of weight decay\n",
    "    logging_dir=DIR_LOG,  # directory for storing logs\n",
    "    seed=SEED,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"f1_loss\",\n",
    "    greater_is_better=False,\n",
    "    run_name=MODEL_NAME,\n",
    ")\n",
    "\n",
    "trainer = tr.MultilabelTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    compute_metrics=tr.compute_metrics,\n",
    ")\n",
    "\n",
    "_ = trainer.train()\n",
    "trainer.evaluate()\n",
    "\n",
    "trainer.model.save_pretrained(f\"{DIR_TRAINED_MODEL}/{MODEL_NAME}/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eed7a1e0",
   "metadata": {},
   "source": [
    "## Evaluate model on test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bcf797aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Prediction *****\n",
      "  Num examples = 990\n",
      "  Batch size = 32\n",
      "***** Running Prediction *****\n",
      "  Num examples = 505\n",
      "  Batch size = 32\n",
      "***** Running Prediction *****\n",
      "  Num examples = 485\n",
      "  Batch size = 32\n"
     ]
    }
   ],
   "source": [
    "results_all = trainer.predict(test_dataset)\n",
    "results_fb = trainer.predict(test_fb_dataset)\n",
    "results_ps = trainer.predict(test_ps_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a68268",
   "metadata": {},
   "source": [
    "### Complete test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d4297ad9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>emotion</th>\n",
       "      <th>Recall</th>\n",
       "      <th>Precision</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>hf_anger</td>\n",
       "      <td>0.831</td>\n",
       "      <td>0.868</td>\n",
       "      <td>0.849</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>hf_fear</td>\n",
       "      <td>0.603</td>\n",
       "      <td>0.675</td>\n",
       "      <td>0.637</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>hf_disgust</td>\n",
       "      <td>0.523</td>\n",
       "      <td>0.634</td>\n",
       "      <td>0.573</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hf_sadness</td>\n",
       "      <td>0.527</td>\n",
       "      <td>0.658</td>\n",
       "      <td>0.586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>hf_joy</td>\n",
       "      <td>0.608</td>\n",
       "      <td>0.680</td>\n",
       "      <td>0.642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>hf_enthusiasm</td>\n",
       "      <td>0.623</td>\n",
       "      <td>0.662</td>\n",
       "      <td>0.642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>hf_pride</td>\n",
       "      <td>0.563</td>\n",
       "      <td>0.618</td>\n",
       "      <td>0.589</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>hf_hope</td>\n",
       "      <td>0.784</td>\n",
       "      <td>0.683</td>\n",
       "      <td>0.730</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         emotion  Recall  Precision     F1\n",
       "0       hf_anger   0.831      0.868  0.849\n",
       "1        hf_fear   0.603      0.675  0.637\n",
       "2     hf_disgust   0.523      0.634  0.573\n",
       "3     hf_sadness   0.527      0.658  0.586\n",
       "4         hf_joy   0.608      0.680  0.642\n",
       "5  hf_enthusiasm   0.623      0.662  0.642\n",
       "6       hf_pride   0.563      0.618  0.589\n",
       "7        hf_hope   0.784      0.683  0.730"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = dict({\"emotion\": emotions})\n",
    "to_add = {\n",
    "    \"Recall\": tr.compute_fine_metrics2(results_all, emotions)[\"recall\"],\n",
    "    \"Precision\": tr.compute_fine_metrics2(results_all, emotions)[\"precision\"],\n",
    "    \"F1\": tr.compute_fine_metrics2(results_all, emotions)[\"f1\"],\n",
    "}\n",
    "df = pd.DataFrame.from_dict(dict(data, **to_add))\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dda1c7f",
   "metadata": {},
   "source": [
    "### Facebook test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1f7a09fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>emotion</th>\n",
       "      <th>Recall</th>\n",
       "      <th>Precision</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>hf_anger</td>\n",
       "      <td>0.833</td>\n",
       "      <td>0.861</td>\n",
       "      <td>0.847</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>hf_fear</td>\n",
       "      <td>0.685</td>\n",
       "      <td>0.698</td>\n",
       "      <td>0.692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>hf_disgust</td>\n",
       "      <td>0.559</td>\n",
       "      <td>0.635</td>\n",
       "      <td>0.595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hf_sadness</td>\n",
       "      <td>0.593</td>\n",
       "      <td>0.698</td>\n",
       "      <td>0.641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>hf_joy</td>\n",
       "      <td>0.675</td>\n",
       "      <td>0.683</td>\n",
       "      <td>0.679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>hf_enthusiasm</td>\n",
       "      <td>0.644</td>\n",
       "      <td>0.702</td>\n",
       "      <td>0.672</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>hf_pride</td>\n",
       "      <td>0.593</td>\n",
       "      <td>0.659</td>\n",
       "      <td>0.624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>hf_hope</td>\n",
       "      <td>0.834</td>\n",
       "      <td>0.738</td>\n",
       "      <td>0.783</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         emotion  Recall  Precision     F1\n",
       "0       hf_anger   0.833      0.861  0.847\n",
       "1        hf_fear   0.685      0.698  0.692\n",
       "2     hf_disgust   0.559      0.635  0.595\n",
       "3     hf_sadness   0.593      0.698  0.641\n",
       "4         hf_joy   0.675      0.683  0.679\n",
       "5  hf_enthusiasm   0.644      0.702  0.672\n",
       "6       hf_pride   0.593      0.659  0.624\n",
       "7        hf_hope   0.834      0.738  0.783"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = dict({\"emotion\": emotions})\n",
    "to_add = {\n",
    "    \"Recall\": tr.compute_fine_metrics2(results_fb, emotions)[\"recall\"],\n",
    "    \"Precision\": tr.compute_fine_metrics2(results_fb, emotions)[\"precision\"],\n",
    "    \"F1\": tr.compute_fine_metrics2(results_fb, emotions)[\"f1\"],\n",
    "}\n",
    "df = pd.DataFrame.from_dict(dict(data, **to_add))\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ad44b1f",
   "metadata": {},
   "source": [
    "### Parliament speech test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "895ffee0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>emotion</th>\n",
       "      <th>Recall</th>\n",
       "      <th>Precision</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>hf_anger</td>\n",
       "      <td>0.828</td>\n",
       "      <td>0.876</td>\n",
       "      <td>0.851</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>hf_fear</td>\n",
       "      <td>0.494</td>\n",
       "      <td>0.635</td>\n",
       "      <td>0.556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>hf_disgust</td>\n",
       "      <td>0.444</td>\n",
       "      <td>0.632</td>\n",
       "      <td>0.522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hf_sadness</td>\n",
       "      <td>0.443</td>\n",
       "      <td>0.600</td>\n",
       "      <td>0.510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>hf_joy</td>\n",
       "      <td>0.517</td>\n",
       "      <td>0.674</td>\n",
       "      <td>0.585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>hf_enthusiasm</td>\n",
       "      <td>0.591</td>\n",
       "      <td>0.605</td>\n",
       "      <td>0.598</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>hf_pride</td>\n",
       "      <td>0.522</td>\n",
       "      <td>0.565</td>\n",
       "      <td>0.543</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>hf_hope</td>\n",
       "      <td>0.721</td>\n",
       "      <td>0.616</td>\n",
       "      <td>0.664</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         emotion  Recall  Precision     F1\n",
       "0       hf_anger   0.828      0.876  0.851\n",
       "1        hf_fear   0.494      0.635  0.556\n",
       "2     hf_disgust   0.444      0.632  0.522\n",
       "3     hf_sadness   0.443      0.600  0.510\n",
       "4         hf_joy   0.517      0.674  0.585\n",
       "5  hf_enthusiasm   0.591      0.605  0.598\n",
       "6       hf_pride   0.522      0.565  0.543\n",
       "7        hf_hope   0.721      0.616  0.664"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = dict({\"emotion\": emotions})\n",
    "to_add = {\n",
    "    \"Recall\": tr.compute_fine_metrics2(results_ps, emotions)[\"recall\"],\n",
    "    \"Precision\": tr.compute_fine_metrics2(results_ps, emotions)[\"precision\"],\n",
    "    \"F1\": tr.compute_fine_metrics2(results_ps, emotions)[\"f1\"],\n",
    "}\n",
    "df = pd.DataFrame.from_dict(dict(data, **to_add))\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ebd155-8e6c-4e4f-9ca2-1fecf4467154",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "22585bf6-72a7-4c2a-b4f6-28b8096b3a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "#end = time.time()\n",
    "#print((end-start_time)/60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455cdcaf-959d-4990-8466-e7556225a432",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25dabd8-2829-4569-a614-1629f7e910fa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m87",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m87"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
